import copy
import random
from tqdm import tqdm
import numpy as np
import os
import logging
from datetime import datetime
import json
import sys
import torch
from nats_bench import create
from typing import Dict, List, Optional, Tuple
import time

# Only the search driver is exported by `from <module> import *`.
__all__ = ["EvolutionFinderNASBench201"]


class ArchManager:
    """Encode, sample and recombine NAS-Bench-201 cell architectures.

    A cell has six edges -- node 1 <- 0; node 2 <- 0, 1; node 3 <- 0, 1, 2 --
    and each edge carries one of the five operations in ``self.operations``.
    Architectures are exchanged as strings of the form
    ``|op~0|+|op~0|op~1|+|op~0|op~1|op~2|`` and manipulated internally as a
    list of six operation indices (edge order: node 1, node 2, node 3).
    """

    def __init__(self):
        self.operations = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
        self.num_ops = len(self.operations)
        self.num_edges = 6

    def random_sample(self):
        """Return an architecture string whose edge ops are drawn uniformly at random."""
        return self._ops_to_arch_str(
            [random.randint(0, self.num_ops - 1) for _ in range(self.num_edges)]
        )

    def mutate_sample(self, arch_str, mutate_prob=0.1):
        """Return ``arch_str`` with each edge re-drawn independently with probability ``mutate_prob``.

        A re-drawn edge may land on its current operation, so the result can
        equal the input.
        """
        genome = self._arch_str_to_ops(arch_str)
        for edge_idx in range(self.num_edges):
            if random.random() < mutate_prob:
                genome[edge_idx] = random.randint(0, self.num_ops - 1)
        return self._ops_to_arch_str(genome)

    def crossover_sample(self, arch_str1, arch_str2):
        """Return a child whose every edge op is taken from either parent with equal odds."""
        mother = self._arch_str_to_ops(arch_str1)
        father = self._arch_str_to_ops(arch_str2)
        child = []
        for edge_idx in range(self.num_edges):
            child.append(random.choice([mother[edge_idx], father[edge_idx]]))
        return self._ops_to_arch_str(child)

    def _arch_str_to_ops(self, arch_str):
        """Parse an architecture string into a list of operation indices.

        Only the first three ``+``-separated node groups are read; node 1's group
        is taken as a single edge token.
        """
        lookup = self.operations.index
        nodes = arch_str.split('+')
        # Node 1 has exactly one incoming edge, so its group is not split on '|'.
        result = [lookup(nodes[0].strip('|').split('~')[0])]
        for position in (1, 2):
            for token in nodes[position].strip('|').split('|'):
                result.append(lookup(token.split('~')[0]))
        return result

    def _ops_to_arch_str(self, ops):
        """Format six operation indices as a NAS-Bench-201 architecture string.

        Raises:
            ValueError: if ``ops`` does not have ``num_edges`` entries.
        """
        if len(ops) != self.num_edges:
            raise ValueError(f"Expected {self.num_edges} operations, got {len(ops)}")

        names = [self.operations[code] for code in ops]
        # (index of the node's first edge in ``ops``, number of incoming edges)
        node_layout = ((0, 1), (1, 2), (3, 3))
        groups = []
        for first, fan_in in node_layout:
            edges = '|'.join(f"{names[first + src]}~{src}" for src in range(fan_in))
            groups.append(f"|{edges}|")
        return '+'.join(groups)


class AccuracyPredictor:
    """Look up architecture accuracies in the NATS-Bench topology search space (TSS).

    Args:
        dataset: NATS-Bench dataset name (the CLI offers 'cifar10', 'cifar100'
            and 'ImageNet16-120').
        metric: 'valid' selects validation accuracy; any other value selects
            test accuracy.

    If the benchmark API cannot be created, the constructor prints a hint and
    terminates the process with exit code 1.
    """

    def __init__(self, dataset='cifar10', metric='test'):
        try:
            self.dataset = dataset
            self.metric = metric
            self.api = create(None, 'tss', fast_mode=True, verbose=False)

            print(f"Successfully loaded NATS-Bench TSS API for {dataset} with {metric} metric")
        except Exception as err:
            print(f"Unable to load NATS-Bench TSS API: {err}")
            print("Please ensure you have downloaded the NATS-Bench TSS data file and set the correct path")
            sys.exit(1)

    def predict_accuracy(self, arch):
        """Return the 200-epoch accuracy of ``arch`` for the configured dataset/metric.

        Any failure (unknown architecture, missing key, API error) is printed
        and reported as 0.0 so the search can continue.
        """
        try:
            index = self.api.query_index_by_arch(arch)
            wants_valid = self.metric == 'valid'
            # For CIFAR-10 validation accuracy the 'cifar10-valid' entry is queried
            # instead of 'cifar10'.
            if self.dataset == 'cifar10' and wants_valid:
                split = 'cifar10-valid'
            else:
                split = self.dataset
            info = self.api.get_more_info(index, split, hp='200', is_random=False)
            return info['valid-accuracy' if wants_valid else 'test-accuracy']
        except Exception as err:
            print(f"Error querying architecture performance: {err}")
            return 0.0


class EvolutionFinderNASBench201:
    """Evolutionary architecture search over the NAS-Bench-201 cell space.

    Every generation keeps the ``parent_ratio`` best members of the current
    population as parents and builds a completely new population of
    architectures that have never been evaluated before: up to
    ``mutation_ratio`` of the slots by mutation, the rest by crossover, and any
    slots still empty by uniform random sampling. Parents are not carried over
    into the next population; the overall best architecture is tracked
    separately in ``best_arch`` / ``best_acc``.

    Args:
        dataset: dataset name forwarded to ``AccuracyPredictor``.
        logger: ``logging.Logger`` used for progress output. When ``None``, a
            module-level logger is used.
        **kwargs: optional settings -- ``metric`` ('test'), ``mutate_prob``
            (0.1), ``population_size`` (50), ``max_time_budget`` (20
            generations), ``parent_ratio`` (0.2), ``mutation_ratio`` (0.5),
            ``seed`` (0; ``None`` disables seeding) and ``max_attempts`` (10;
            base retry budget when looking for unvisited children).
    """

    def __init__(
        self,
        dataset='cifar10',
        logger=None,
        **kwargs
    ):
        self.dataset = dataset
        self.arch_manager = ArchManager()
        self.metric = kwargs.get("metric", "test")
        self.accuracy_predictor = AccuracyPredictor(dataset, self.metric)
        # Every method logs through self.logger, so never leave it as None.
        self.logger = logger if logger is not None else logging.getLogger(__name__)

        self.mutate_prob = kwargs.get("mutate_prob", 0.1)
        self.population_size = kwargs.get("population_size", 50)
        self.max_time_budget = kwargs.get("max_time_budget", 20)
        self.parent_ratio = kwargs.get("parent_ratio", 0.2)
        self.mutation_ratio = kwargs.get("mutation_ratio", 0.5)
        self.seed = kwargs.get("seed", 0)

        self.visited = set()          # architecture strings already sampled and evaluated
        self.arch_performances = {}   # architecture string -> cached accuracy
        self.max_attempts = kwargs.get("max_attempts", 10)

        self.total_explored = 0       # number of evaluate_arch calls so far
        self.best_arch = None
        self.best_acc = 0.0
        self.best_found_at = 0        # total_explored value when best_arch was found

        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
            print(f"Random seed set to: {self.seed}")

    def evaluate_arch(self, arch_str):
        """Return the accuracy of ``arch_str`` and update the best-so-far record.

        Accuracies are cached in ``arch_performances``; ``total_explored`` is
        incremented on every call, cached or not.
        """
        if arch_str in self.arch_performances:
            acc = self.arch_performances[arch_str]
        else:
            acc = self.accuracy_predictor.predict_accuracy(arch_str)
            self.arch_performances[arch_str] = acc

        self.total_explored += 1

        if acc > self.best_acc:
            self.best_acc = acc
            self.best_arch = arch_str
            self.best_found_at = self.total_explored
            self.logger.info(f"New best architecture found! Acc: {acc:.4f}, Explored: {self.total_explored}")
            self.logger.info(f"Architecture: {arch_str}")

        return acc

    def generate_initial_population(self):
        """Return up to ``population_size`` random, previously unvisited ``[acc, arch]`` pairs.

        Stops early with a warning after ``100 * max_attempts`` consecutive
        draws of already-visited architectures instead of looping forever when
        fewer unvisited architectures remain than requested.
        """
        population = []
        max_misses = 100 * self.max_attempts
        misses = 0

        while len(population) < self.population_size:
            arch = self.arch_manager.random_sample()
            if arch in self.visited:
                misses += 1
                if misses >= max_misses:
                    self.logger.warning(
                        f"Could only sample {len(population)} unvisited architectures for the "
                        f"initial population (target: {self.population_size})"
                    )
                    break
                continue
            misses = 0
            self.visited.add(arch)
            acc = self.evaluate_arch(arch)
            population.append([acc, arch])

        return population

    def _first_unvisited(self, make_candidate, method):
        """Call ``make_candidate`` until it yields an unvisited architecture.

        After ``5 * max_attempts`` failures, falls back to
        ``random_unique_sample``, which may still return a visited architecture.
        """
        max_attempts = self.max_attempts * 5
        for _ in range(max_attempts):
            candidate = make_candidate()
            if candidate not in self.visited:
                return candidate

        self.logger.warning(f"Failed to search new architecture by {method} after {max_attempts} attempts")
        return self.random_unique_sample()

    def mutate_sample(self, sample):
        """Return a mutation of ``sample``, preferring architectures not yet visited."""
        return self._first_unvisited(
            lambda: self.arch_manager.mutate_sample(sample, self.mutate_prob), "mutation"
        )

    def crossover_sample(self, sample1, sample2):
        """Return a crossover child of two parents, preferring unvisited architectures."""
        return self._first_unvisited(
            lambda: self.arch_manager.crossover_sample(sample1, sample2), "crossover"
        )

    def random_unique_sample(self):
        """Return a random architecture, preferring one not yet visited.

        Tries up to 100 draws; if all are visited, the last draw is returned
        anyway, so callers must still check ``self.visited``.
        """
        sample = None
        for _ in range(100):
            sample = self.arch_manager.random_sample()
            if sample not in self.visited:
                break
        return sample

    def _breed_generation(self, parents, mutation_numbers):
        """Build the next population of unvisited ``[acc, arch]`` pairs from ``parents``.

        Slots are filled by mutation (up to ``mutation_numbers``), then by
        crossover, then by random sampling. The result can be smaller than
        ``population_size`` when unvisited architectures become hard to find.
        """
        population_size = self.population_size
        new_population = []

        def admit(arch):
            # Evaluate and keep ``arch`` only if it has never been visited.
            if arch in self.visited:
                return False
            self.visited.add(arch)
            new_population.append([self.evaluate_arch(arch), arch])
            return True

        mutation_count = crossover_count = random_count = 0

        # Parents are drawn with random.choice so a population that shrank below
        # parents_size can never be indexed past its end.
        for _ in range(mutation_numbers * self.max_attempts):
            if len(new_population) >= mutation_numbers:
                break
            if admit(self.mutate_sample(random.choice(parents)[1])):
                mutation_count += 1

        needed_crossovers = population_size - len(new_population)
        for _ in range(needed_crossovers * self.max_attempts):
            if len(new_population) >= population_size:
                break
            parent1_arch = random.choice(parents)[1]
            parent2_arch = random.choice(parents)[1]
            if admit(self.crossover_sample(parent1_arch, parent2_arch)):
                crossover_count += 1

        random_needed = population_size - len(new_population)
        if random_needed > 0:
            self.logger.info(f"Need to supplement {random_needed} architectures through random sampling")
            for _ in range(random_needed):
                # random_unique_sample may return a visited arch; give it 100 chances.
                if any(admit(self.random_unique_sample()) for _attempt in range(100)):
                    random_count += 1
                else:
                    # Further slots would fail the same way; stop instead of
                    # burning another 10k draws per slot.
                    self.logger.warning(
                        "Failed to find new architecture after 100 random sampling attempts; "
                        "skipping remaining random supplementation"
                    )
                    break

        self.logger.info(f"Architecture sources: mutation={mutation_count}, crossover={crossover_count}, random={random_count}")
        self.logger.info(f"New population size: {len(new_population)} (target: {population_size})")
        if len(new_population) < population_size:
            self.logger.warning(f"Warning: Population size ({len(new_population)}) is less than target size ({population_size})")

        return new_population

    def run_evolution_search(self):
        """Run the evolutionary search.

        Returns:
            ``(best_valids, [best_acc, best_arch], best_found_at)`` where
            ``best_valids[0]`` is the best accuracy after the initial population
            and ``best_valids[i]`` the best accuracy after generation ``i``;
            ``best_found_at`` is the 1-based evaluation count at which the best
            architecture was first seen (0 if nothing beat 0.0). The search stops
            early if a generation produces no new architectures.
        """
        max_time_budget = self.max_time_budget
        population_size = self.population_size
        mutation_numbers = int(round(self.mutation_ratio * population_size))
        # An empty parent pool cannot breed, so keep at least one parent.
        parents_size = max(1, int(round(self.parent_ratio * population_size)))

        self.logger.info(f"Starting evolution search for NAS-Bench-201 on {self.dataset} with {self.metric} metric...")
        self.logger.info(f"Population: {population_size}, Parents: {parents_size}, Mutations: {mutation_numbers}")

        # Initialize the first generation with random architectures
        population = self.generate_initial_population()
        best_valids = [self.best_acc]

        for iteration in tqdm(range(max_time_budget)):
            if not population:
                # Nothing to select parents from (no unvisited architecture could
                # be generated, or population_size is 0).
                self.logger.warning("Population is empty; no new architectures to breed from, stopping early")
                break

            # Select top performing architectures as parents; the list may be
            # shorter than parents_size if the population shrank.
            parents = sorted(population, key=lambda x: x[0], reverse=True)[:parents_size]
            current_best_acc = parents[0][0]

            self.logger.info(f"\nIteration {iteration+1}/{max_time_budget}:")
            self.logger.info(f"Current best acc: {current_best_acc:.4f}, Overall best: {self.best_acc:.4f}")
            self.logger.info(f"Explored architectures: {self.total_explored}")

            population = self._breed_generation(parents, mutation_numbers)
            # Record after breeding so the history ends at the returned best.
            best_valids.append(self.best_acc)

        found_fraction = self.best_found_at / self.total_explored if self.total_explored else 0.0
        self.logger.info(f"\nEvolution search completed!")
        self.logger.info(f"Total explored architectures: {self.total_explored}")
        self.logger.info(f"Best architecture found at #{self.best_found_at} (after {found_fraction:.2%} of total):")
        self.logger.info(f"Accuracy: {self.best_acc:.4f}")
        self.logger.info(f"Architecture: {self.best_arch}")

        return best_valids, [self.best_acc, self.best_arch], self.best_found_at


def main():
    """Command-line entry point: parse arguments, configure logging, run the search.

    Log records go both to stderr and to a file under
    ``search_logs/<script name>/<dataset>-<metric>/``.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Search for high-performance architectures in NAS-Bench-201 using evolution')
    # Fixed: a missing comma between default= and help= made this a SyntaxError.
    parser.add_argument('--iterations', type=int, default=9, help='Number of search iterations')
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'ImageNet16-120'],
                        help='Dataset to optimize architecture for')
    parser.add_argument('--metric', type=str, default='test', choices=['test', 'valid'],
                        help='Which metric to use for evaluation (test or validation accuracy)')
    parser.add_argument('--population', type=int, default=10, help='Population size')
    parser.add_argument('--parent_ratio', type=float, default=0.5, help='Ratio of parents to keep')
    parser.add_argument('--mutation_ratio', type=float, default=0.5, help='Ratio of mutation vs crossover')
    parser.add_argument('--mutate_prob', type=float, default=0.1, help='Per-edge mutation probability')
    parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducibility')
    args = parser.parse_args()

    # Script stem without its extension, used to group log directories.
    current_file = os.path.splitext(os.path.basename(__file__))[0]
    log_dir = f"search_logs/{current_file}/{args.dataset}-{args.metric}"
    os.makedirs(log_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(log_dir, f"seed{args.seed}_{args.dataset}_{args.metric}_p{args.population}_e{args.iterations}_pr{args.parent_ratio}_{timestamp}.log")),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger("NAS-Evolution")

    finder = EvolutionFinderNASBench201(
        dataset=args.dataset,
        logger=logger,
        max_time_budget=args.iterations,
        population_size=args.population,
        parent_ratio=args.parent_ratio,
        mutation_ratio=args.mutation_ratio,
        mutate_prob=args.mutate_prob,
        seed=args.seed,
        metric=args.metric
    )

    print(f"Starting evolution search on {args.dataset} with {args.metric} metric, iterations: {args.iterations}")
    best_valids, best_info, best_found_at = finder.run_evolution_search()

    # Guard against division by zero when nothing was explored (e.g. --population 0).
    explored = finder.total_explored
    found_fraction = best_found_at / explored if explored else 0.0

    print("\nBest architecture:")
    print(f"Best architecture found at sample: #{best_found_at} (after {found_fraction:.2%} of total)")
    print(f"Architecture string: {best_info[1]}")
    print(f"Accuracy on {args.dataset} ({args.metric}): {best_info[0]:.4f}")


# Run the command-line search only when executed as a script, not on import.
if __name__ == "__main__":
    main()
